#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function
import sys
import os
import random #for random sort
import datetime #for date manipulations
import re #for regexes
import codecs
from collections import OrderedDict, defaultdict

import rbql_utils

# Template placeholder: as written this is a bare, undefined name, so the module is
# only runnable after the placeholder has been substituted.
# NOTE(review): presumably replaced with user-requested import statements by the
# code generator -- confirm against the generator.
__RBQLMP__import_expression


# True when running under Python 3; read by the py2/py3 compatibility helpers below.
PY3 = sys.version_info[0] == 3


def iteritems6(x):
    """Return the key/value pairs of mapping `x` on both Python 2 and Python 3."""
    if not PY3:
        # Python 2: use the lazy iterator variant.
        return x.iteritems()
    return x.items()


def str6(obj):
    """Convert `obj` to text, returning objects that already are strings untouched.

    Plain str() cannot be used unconditionally because under Python 2.7 it
    tries to ascii-encode unicode strings.
    """
    # The conditional expression keeps `basestring` from being evaluated on Python 3.
    text_types = str if PY3 else basestring
    if isinstance(obj, text_types):
        return obj
    return str(obj)


# Field delimiters for table A (input), table B (join) and the output, plus the
# encoding used to open the join table and to detect a UTF-8 BOM.
# NOTE(review): the string values are template placeholders, presumably substituted
# with the real settings by the code generator -- confirm against the generator.
input_delim = '__RBQLMP__input_delim'
join_delim = '__RBQLMP__join_delim'
output_delim = '__RBQLMP__output_delim'
csv_encoding='__RBQLMP__csv_encoding'

def eprint(*args, **kwargs):
    """print() wrapper that always writes to standard error."""
    err_stream = sys.stderr
    print(*args, file=err_stream, **kwargs)


class BadFieldError(Exception):
    """Raised when a record is accessed at a column index it does not have.

    `bad_idx` is the zero-based index of the missing column.
    """

    def __init__(self, bad_idx):
        self.bad_idx = bad_idx

class RbqlRuntimeError(Exception):
    """Error raised while executing a query; carries a human-readable message."""


def safe_get(record, idx):
    """Return the field of `record` at one-based column number `idx`.

    Raises BadFieldError (carrying the zero-based index) if the record is too short.
    """
    zero_based = idx - 1
    try:
        return record[zero_based]
    except IndexError:
        raise BadFieldError(zero_based)


def safe_set(record, idx, value):
    """Store `value` into `record` at one-based column number `idx`.

    Raises BadFieldError (carrying the zero-based index) if the record is too short.
    """
    zero_based = idx - 1
    try:
        record[zero_based] = value
    except IndexError:
        raise BadFieldError(zero_based)

# When True, monocolumn input whose output records have more than one field is
# written as quoted CSV instead of failing (see mono_join).
enable_monocolumn_csv_ux_optimization_hack = True

# Set by rb_transform on entry; guards against running the transform twice in one module instance.
module_was_used_failsafe = False

# State collected for the warnings report:
# Map "number of fields" -> first line number where a record with that many fields was seen.
join_fields_info = dict()
input_fields_info = dict()
# Set when a None value in an output record was replaced with an empty string.
null_value_in_output = False
delim_in_simple_output = False # Needs refactoring to also report the line number.
output_switch_to_csv = None # Reported as a warning because a user working with TSV files would not expect CSV output. Also a marker for the editor to use CSV syntax for the output instead of monocolumn.
utf8_bom_removed = False
# First line number for which the splitter reported a warning; 0 means none.
defective_csv_line_in_input = 0
defective_csv_line_in_join = 0

# Aggregation state:
# 0 - no aggregate function seen, 1 - aggregators are being registered, 2 - aggregators are set up.
aggregation_stage = 0
# Number of aggregators registered so far; doubles as the id of the next Marker.
aggr_init_counter = 0
functional_aggregators = list()

# Active output writer; assigned in rb_transform.
writer = None

NU = 0 # NU - Num Updated. Alternative variables: NW (Num Where) - not practical. NW (Num Written) - impossible to implement.


class Marker(object):
    """Placeholder for an aggregated column value.

    `marker_id` is the position of the matching aggregator in the list of
    registered aggregators; `value` is the value passed to the aggregate function.
    """

    def __init__(self, marker_id, value):
        self.value = value
        self.marker_id = marker_id

    def __str__(self):
        # A marker must never be rendered as text; fail loudly if that is attempted.
        raise TypeError('Marker')


def init_aggregator(generator_name, val):
    """Register a new aggregator and return a Marker that refers to it.

    `generator_name` is called without arguments to build the aggregator.
    The returned Marker carries the aggregator's position and `val`.
    Switches the module into aggregation stage 1.
    """
    global aggregation_stage, aggr_init_counter
    aggregation_stage = 1
    marker_id = aggr_init_counter
    # The counter and the aggregator list must stay in lockstep.
    assert marker_id == len(functional_aggregators)
    functional_aggregators.append(generator_name())
    marker = Marker(marker_id, val)
    aggr_init_counter = marker_id + 1
    return marker


def MIN(val):
    """Aggregate function: registers a MinAggregator until aggregators are set up, then passes `val` through."""
    if aggregation_stage >= 2:
        return val
    return init_aggregator(rbql_utils.MinAggregator, val)


def MAX(val):
    """Aggregate function: registers a MaxAggregator until aggregators are set up, then passes `val` through."""
    if aggregation_stage >= 2:
        return val
    return init_aggregator(rbql_utils.MaxAggregator, val)


def COUNT(val):
    """Aggregate function: registers a CountAggregator until aggregators are set up.

    `val` itself is ignored; every record contributes the constant 1.
    """
    if aggregation_stage >= 2:
        return 1
    return init_aggregator(rbql_utils.CountAggregator, 1)


def SUM(val):
    """Aggregate function: registers a SumAggregator until aggregators are set up, then passes `val` through."""
    if aggregation_stage >= 2:
        return val
    return init_aggregator(rbql_utils.SumAggregator, val)


def AVG(val):
    """Aggregate function: registers an AvgAggregator until aggregators are set up, then passes `val` through."""
    if aggregation_stage >= 2:
        return val
    return init_aggregator(rbql_utils.AvgAggregator, val)


def VARIANCE(val):
    """Aggregate function: registers a VarianceAggregator until aggregators are set up, then passes `val` through."""
    if aggregation_stage >= 2:
        return val
    return init_aggregator(rbql_utils.VarianceAggregator, val)


def MEDIAN(val):
    """Aggregate function: registers a MedianAggregator until aggregators are set up, then passes `val` through."""
    if aggregation_stage >= 2:
        return val
    return init_aggregator(rbql_utils.MedianAggregator, val)


def quote_field(src, delim):
    """Return `src` ready for CSV output.

    A field containing a double quote or `delim` is wrapped in double quotes
    with embedded quotes doubled; any other field is returned unchanged.
    """
    needs_quoting = src.find('"') != -1 or src.find(delim) != -1
    if not needs_quoting:
        return src
    return '"{}"'.format(src.replace('"', '""'))


def quoted_join(fields, delim):
    """Join `fields` with `delim`, converting each to text and quoting it where required."""
    quoted = []
    for field in fields:
        quoted.append(quote_field(str6(field), delim))
    return delim.join(quoted)


def mono_join(fields, delim):
    """Join `fields` for the "monocolumn" output policy.

    Normally a record must have exactly one field, which is returned as text.
    When the monocolumn UX hack is enabled and the input is monocolumn too,
    multi-field output switches to quoted CSV instead; the decision is taken
    on the first record and stored in `output_switch_to_csv`.

    `delim` is accepted for interface parity with the other join functions
    and is not used.

    Raises RbqlRuntimeError if a record has more than one field and the CSV
    switch does not apply.
    """
    global output_switch_to_csv
    if enable_monocolumn_csv_ux_optimization_hack and '__RBQLMP__input_policy' == 'monocolumn':
        if output_switch_to_csv is None:
            output_switch_to_csv = (len(fields) > 1)
        # Every record must agree with the decision taken on the first one.
        assert output_switch_to_csv == (len(fields) > 1), 'Monocolumn optimization logic failure'
        if output_switch_to_csv:
            return quoted_join(fields, ',')
        # Convert to text like the other join functions do: the writers pass
        # the record straight to dst.write(), which rejects non-string values.
        return str6(fields[0])
    if len(fields) > 1:
        raise RbqlRuntimeError('Unable to use "Monocolumn" output format: some records have more than one field')
    return str6(fields[0])


def simple_join(fields, delim):
    """Join `fields` with `delim` without any quoting.

    Sets the `delim_in_simple_output` warning flag when the joined record
    contains more delimiters than expected, i.e. a field itself contained `delim`.
    """
    global delim_in_simple_output
    pieces = [str6(field) for field in fields]
    joined = delim.join(pieces)
    expected_delims = len(fields) - 1
    if joined.count(delim) != expected_delims:
        delim_in_simple_output = True
    return joined


# Select the function that joins output fields into a record string.
# The left-hand string literals in the comparisons below are template placeholders
# (NOTE(review): presumably substituted by the code generator before execution);
# an unrecognised value aborts module loading with RuntimeError.
if '__RBQLMP__output_policy' == 'simple':
    output_join = simple_join
elif '__RBQLMP__output_policy' == 'quoted':
    output_join = quoted_join
elif '__RBQLMP__output_policy' == 'monocolumn':
    output_join = mono_join
else:
    raise RuntimeError('unknown output csv policy')


# Select the splitter for lines of input table A.
# Every splitter returns a (fields, warning) pair; the simple and monocolumn
# splitters always report 0 as the warning.
if '__RBQLMP__input_policy' == 'simple':
    input_split = lambda line: (line.split(input_delim), 0)
elif '__RBQLMP__input_policy' == 'quoted':
    input_split = lambda line: rbql_utils.split_quoted_str(line, input_delim)
elif '__RBQLMP__input_policy' == 'monocolumn':
    input_split = lambda line: ([line], 0)
else:
    raise RuntimeError('unknown input csv policy')


# Select the splitter for lines of join table B.
# An empty policy installs a stub splitter that returns (None, None).
if '__RBQLMP__join_policy' == '':
    join_split = lambda line: (None, None)
elif '__RBQLMP__join_policy' == 'simple':
    join_split = lambda line: (line.split(join_delim), 0)
elif '__RBQLMP__join_policy' == 'quoted':
    join_split = lambda line: rbql_utils.split_quoted_str(line, join_delim)
elif '__RBQLMP__join_policy' == 'monocolumn':
    join_split = lambda line: ([line], 0)
else:
    raise RuntimeError('unknown join csv policy')


def add_to_set(dst_set, value):
    """Add `value` to `dst_set`; return True if it was not already present."""
    size_before = len(dst_set)
    dst_set.add(value)
    grew = size_before != len(dst_set)
    return grew


class SimpleWriter(object):
    """Writes records to `dst`, one per line, up to the optional record limit.

    `NW` counts the records written so far.
    """

    def __init__(self, dst):
        self.NW = 0
        self.dst = dst

    def write(self, record):
        """Write `record`; return False once the record limit has been reached."""
        limit_reached = __RBQLMP__top_count is not None and self.NW >= __RBQLMP__top_count
        if limit_reached:
            return False
        out = self.dst
        out.write(record)
        out.write('\n')
        self.NW += 1
        return True

    def finish(self):
        """Nothing is buffered, so there is nothing to flush."""
        return None


class UniqWriter(object):
    """Writes each distinct record to `dst` once, up to the optional record limit.

    `seen` holds the records already written; `NW` counts them.
    """

    def __init__(self, dst):
        self.NW = 0
        self.seen = set()
        self.dst = dst

    def write(self, record):
        """Write `record` unless it is a duplicate; return False once the limit is reached."""
        limit_reached = __RBQLMP__top_count is not None and self.NW >= __RBQLMP__top_count
        if limit_reached:
            return False
        is_new = add_to_set(self.seen, record)
        if is_new:
            out = self.dst
            out.write(record)
            out.write('\n')
            self.NW += 1
        # Duplicates are skipped silently and do not stop processing.
        return True

    def finish(self):
        """Nothing is buffered, so there is nothing to flush."""
        return None


class UniqCountWriter(object):
    """Counts occurrences of each distinct record and writes "count, record" lines at the end.

    This can't be optimized for an ordered result set: the records might be
    ordered not by the output record but by some other expression.
    """

    def __init__(self, dst):
        self.records = OrderedDict()
        self.dst = dst

    def write(self, record):
        """Count one more occurrence of `record`; always returns True."""
        self.records[record] = self.records.get(record, 0) + 1
        return True

    def finish(self):
        """Write the counted records in first-seen order, honouring the optional record limit."""
        written = 0
        for record, count in iteritems6(self.records):
            if __RBQLMP__top_count is not None and written >= __RBQLMP__top_count:
                break
            self.dst.write('{}{}{}\n'.format(count, output_delim, record))
            written += 1


class SortedWriter(object):
    """Buffers (sort key, record) pairs and forwards the records to `subwriter` in sorted order."""

    def __init__(self, subwriter):
        self.unsorted_entries = list()
        self.subwriter = subwriter

    def write(self, sort_key_value, record):
        """Buffer `record` under `sort_key_value`; always returns True."""
        entry = (sort_key_value, record)
        self.unsorted_entries.append(entry)
        return True

    def finish(self):
        """Sort the buffered entries by key and pass the records on until the subwriter refuses one."""
        ordered = list(self.unsorted_entries)
        ordered.sort(key=lambda entry: entry[0])
        if __RBQLMP__reverse_flag:
            ordered.reverse()
        for _, record in ordered:
            if not self.subwriter.write(record):
                break
        self.subwriter.finish()


class AggregateWriter(object):
    """Holds per-column aggregators and emits one output record per aggregation key.

    `aggregators` has one entry per output column; `aggregation_keys` is the
    set of keys seen so far. Both are filled in from outside the class.
    """

    def __init__(self, subwriter):
        self.aggregation_keys = set()
        self.aggregators = []
        self.subwriter = subwriter

    def finish(self):
        """Write the final aggregated record of every key, in sorted key order."""
        for key in sorted(self.aggregation_keys):
            final_values = [aggregator.get_final(key) for aggregator in self.aggregators]
            record = output_join(final_values, output_delim)
            if not self.subwriter.write(record):
                break
        self.subwriter.finish()


def remove_utf8_bom(line, assumed_source_encoding):
    """Strip a leading UTF-8 byte order mark from `line`, if present.

    The mark looks different depending on how the text was decoded: three
    characters for 'latin-1', a single U+FEFF character for 'utf-8'.
    For any other encoding the line is returned unchanged.
    """
    known_boms = (
        ('latin-1', '\xef\xbb\xbf'),
        ('utf-8', u'\ufeff'),
    )
    for encoding, bom in known_boms:
        if assumed_source_encoding == encoding and line[:len(bom)] == bom:
            return line[len(bom):]
    return line


def read_join_table(join_table_path):
    """Load join table B from `join_table_path` into memory, grouped by join key.

    Returns a tuple `(result, fields_max_len)`: `result` maps every join key to
    the list of B records (field lists) that have this key, and
    `fields_max_len` is the largest number of fields seen in any B record.

    Side effects on module state: records in `join_fields_info` the first line
    number at which each distinct field count was seen, stores in
    `defective_csv_line_in_join` the first line for which the splitter
    reported a warning, and sets `utf8_bom_removed` when a byte order mark was
    stripped from the first line.

    Raises RbqlRuntimeError if the path is not a regular file or if a B record
    lacks the column used as the join key.
    """
    global join_fields_info
    global defective_csv_line_in_join
    global utf8_bom_removed
    fields_max_len = 0
    if not os.path.isfile(join_table_path):
        raise RbqlRuntimeError('Table B: ' + join_table_path + ' is not accessible')
    result = defaultdict(list)
    with codecs.open(join_table_path, encoding=csv_encoding) as src_text:
        # Line numbers are one-based so they can be shown to the user as is.
        for il, line in enumerate(rbql_utils.rows(src_text), 1):
            if il == 1:
                # A byte order mark can only appear at the very start of the file.
                clean_line = remove_utf8_bom(line, csv_encoding)
                if clean_line != line:
                    line = clean_line
                    utf8_bom_removed = True
            line = line.rstrip('\r\n')
            bfields, warning = join_split(line)
            # Remember only the first defective line.
            if warning and not defective_csv_line_in_join:
                defective_csv_line_in_join = il
            num_fields = len(bfields)
            fields_max_len = max(fields_max_len, num_fields)
            if num_fields not in join_fields_info:
                join_fields_info[num_fields] = il
            try:
                # Template placeholder for the join key expression.
                # NOTE(review): presumably it reads from the local `bfields` -- confirm against the generator.
                key = __RBQLMP__rhs_join_var
            except BadFieldError as e:
                bad_idx = e.bad_idx
                # bad_idx is zero-based; column names shown to the user are one-based.
                raise RbqlRuntimeError('No "b' + str(bad_idx + 1) + '" column at line: ' + str(il) + ' in "B" table')
            result[key].append(bfields)
    return (result, fields_max_len)


def none_joiner(path):
    """Stand-in joiner factory: ignores `path` and always yields None (no joiner)."""
    joiner = None
    return joiner


class InnerJoiner(object):
    """Join lookup in which an A key without matches in B yields no records."""

    def __init__(self, join_table_path):
        table, max_len = read_join_table(join_table_path)
        self.join_data = table
        self.fields_max_len = max_len

    def get_rhs(self, lhs_key):
        """Return the list of B records stored under `lhs_key` (empty if there are none)."""
        no_match = []
        return self.join_data.get(lhs_key, no_match)


class LeftJoiner(object):
    """Join lookup in which an A key without matches in B yields one all-None record."""

    def __init__(self, join_table_path):
        table, max_len = read_join_table(join_table_path)
        self.join_data = table
        self.fields_max_len = max_len
        # A single record as wide as the widest B record, filled with None.
        blank_fields = [None] * self.fields_max_len
        self.null_record = [blank_fields]

    def get_rhs(self, lhs_key):
        """Return the B records stored under `lhs_key`, or the null record list if there are none."""
        fallback = self.null_record
        return self.join_data.get(lhs_key, fallback)


class StrictLeftJoiner(object):
    """Join lookup that requires exactly one B record for every A key."""

    def __init__(self, join_table_path):
        table, max_len = read_join_table(join_table_path)
        self.join_data = table
        self.fields_max_len = max_len

    def get_rhs(self, lhs_key):
        """Return the single-element list of B records for `lhs_key`.

        Raises RbqlRuntimeError when the key has no match or more than one match.
        """
        matches = self.join_data.get(lhs_key, [])
        if len(matches) == 1:
            return matches
        raise RbqlRuntimeError('In "STRICT LEFT JOIN" each key in A must have exactly one match in B. Bad A key: "' + lhs_key + '"')


def replace_none_values(fields):
    """Replace every None in `fields` with an empty string, in place.

    Sets the `null_value_in_output` warning flag if anything was replaced.
    """
    global null_value_in_output
    for pos, field in enumerate(fields):
        if field is not None:
            continue
        fields[pos] = ''
        null_value_in_output = True


def create_warnings_report():
    """Collect the warning state gathered during the run.

    Returns a dict keyed by warning name, or None if nothing needs reporting.
    """
    # (warning name, should it be reported, value to report)
    candidates = (
        ('defective_csv_line_in_join', defective_csv_line_in_join != 0, defective_csv_line_in_join),
        ('defective_csv_line_in_input', defective_csv_line_in_input != 0, defective_csv_line_in_input),
        ('null_value_in_output', null_value_in_output, True),
        ('delim_in_simple_output', delim_in_simple_output, True),
        ('output_switch_to_csv', output_switch_to_csv, True),
        ('utf8_bom_removed', utf8_bom_removed, True),
        # Differing field counts are only worth reporting when more than one count was seen.
        ('input_fields_info', len(input_fields_info) > 1, input_fields_info),
        ('join_fields_info', len(join_fields_info) > 1, join_fields_info),
    )
    report = dict()
    for name, is_active, value in candidates:
        if is_active:
            report[name] = value
    if report:
        return report
    return None


def main():
    """Run the query from stdin to stdout and print any warnings to stderr."""
    report = rb_transform(sys.stdin, sys.stdout)
    if report is None:
        return
    for name, details in report.items():
        eprint(name, details)


def process_update(NR, NF, afields, rhs_records):
    """Handle one record of table A for an UPDATE query.

    `NR` is the one-based record number, `NF` the number of fields, `afields`
    the record's fields and `rhs_records` the matching records of table B
    (None when no join is used). The record is always written out, updated or
    not; the return value is the writer's result.

    Raises RbqlRuntimeError if more than one B record matched.
    """
    if rhs_records is not None and len(rhs_records) > 1:
        raise RbqlRuntimeError('More than one record in UPDATE query matched A-key in join table B')
    bfields = None
    if rhs_records is not None and len(rhs_records) == 1:
        bfields = rhs_records[0]
    # The update applies only when there is no join or exactly one B match, and the
    # WHERE condition (a template placeholder) holds.
    # NOTE(review): the substituted expressions presumably read the local names
    # NR, NF, afields and bfields -- confirm against the generator.
    if (rhs_records is None or len(rhs_records) == 1) and (__RBQLMP__where_expression):
        global NU
        NU += 1
        __RBQLMP__update_statements
    # The record is modified in place, so the A fields themselves form the output.
    out_fields = afields
    replace_none_values(out_fields)
    out_record = output_join(out_fields, output_delim)
    return writer.write(out_record)


def select_simple(NR, NF, afields, bfields, out_fields):
    """Write one non-aggregated SELECT result record.

    `out_fields` are the evaluated output columns; None values in it are
    replaced with empty strings in place. Returns False when the writer
    refuses the record, True otherwise.

    NOTE(review): NR, NF, afields and bfields are not used by the visible code;
    presumably they are read by the substituted sort key expression -- confirm
    against the generator.
    """
    replace_none_values(out_fields)
    out_record = output_join(out_fields, output_delim)
    if __RBQLMP__sort_flag:
        # With sorting, the writer takes the sort key in addition to the record.
        sort_key_value = (__RBQLMP__sort_key_expression)
        if not writer.write(sort_key_value, out_record):
            return False
    else:
        if not writer.write(out_record):
            return False
    return True


def select_aggregated(NR, NF, afields, bfields, transparent_values):
    """Feed one record into the aggregators of an aggregate SELECT query.

    `transparent_values` holds one entry per output column: a Marker for a
    column produced by an aggregate function, a plain value otherwise.

    On the first call (aggregation stage 1) the module-level writer is wrapped
    in an AggregateWriter and one aggregator per column is installed; later
    calls only update the installed aggregators.

    Raises RbqlRuntimeError on the first call if the active writer is not a
    plain SimpleWriter.
    """
    global aggregation_stage
    # Template placeholder for the grouping key expression.
    key = __RBQLMP__aggregation_key_expression
    if key is not None:
        # Turn the key fields into a single string.
        key = output_join(key, output_delim)
    if aggregation_stage == 1:
        global writer
        if type(writer) is not SimpleWriter:
            raise RbqlRuntimeError('Unable to use "ORDER BY" or "DISTINCT" keywords in aggregate query')
        writer = AggregateWriter(writer)
        for i, trans_value in enumerate(transparent_values):
            if isinstance(trans_value, Marker):
                # Column comes from an aggregate function: use the aggregator registered for this marker.
                writer.aggregators.append(functional_aggregators[trans_value.marker_id])
                writer.aggregators[-1].increment(key, trans_value.value)
            else:
                # Plain column.
                # NOTE(review): SubkeyChecker presumably verifies the value is constant per key -- confirm in rbql_utils.
                writer.aggregators.append(rbql_utils.SubkeyChecker())
                writer.aggregators[-1].increment(key, trans_value)
        # From now on the aggregate functions pass their argument through unchanged.
        aggregation_stage = 2
    else:
        for i, trans_value in enumerate(transparent_values):
            writer.aggregators[i].increment(key, trans_value)
    writer.aggregation_keys.add(key)


def process_select(NR, NF, afields, rhs_records):
    """Handle one record of table A for a SELECT query.

    `rhs_records` are the matching records of table B, or None when no join
    is used. Each (A record, B record) pair that passes the WHERE condition
    is evaluated and passed on to the aggregated or simple output path.

    Returns False when the writer refuses a record, True otherwise.
    """
    if rhs_records is None:
        # No join: process the A record once, with no B fields.
        rhs_records = [None]
    for bfields in rhs_records:
        # NOTE(review): star_fields is not used by the visible code; presumably the
        # substituted select expression reads it -- confirm against the generator.
        if bfields is None:
            star_fields = afields
        else:
            star_fields = afields + bfields
        if not (__RBQLMP__where_expression):
            continue
        out_fields = __RBQLMP__select_expression
        # Evaluating the select expression may have switched the module into aggregation mode.
        if aggregation_stage > 0:
            select_aggregated(NR, NF, afields, bfields, out_fields)
        else:
            if not select_simple(NR, NF, afields, bfields, out_fields):
                return False
    return True


def rb_transform(source, destination):
    """Run the query: read records from `source`, write results to `destination`.

    Returns the warnings report (a dict) or None when there are no warnings.
    May be called only once per module instance, because the query state is
    kept in module-level variables.

    Raises RbqlRuntimeError for a missing column or any other error raised
    while a record is processed; the message includes the line number.
    """
    global module_was_used_failsafe
    assert not module_was_used_failsafe
    module_was_used_failsafe = True

    global input_fields_info
    global defective_csv_line_in_input
    global utf8_bom_removed
    global writer
    writer = __RBQLMP__writer_type(destination)
    if __RBQLMP__sort_flag:
        writer = SortedWriter(writer)
    joiner = __RBQLMP__joiner_type('__RBQLMP__rhs_table_path')
    # Record numbers are one-based so they can be shown to the user as is.
    for NR, line in enumerate(rbql_utils.rows(source), 1):
        if NR == 1:
            # A byte order mark can only appear at the very start of the input.
            clean_line = remove_utf8_bom(line, csv_encoding)
            if clean_line != line:
                line = clean_line
                utf8_bom_removed = True
        line = line.rstrip('\r\n')
        afields, warning = input_split(line)
        # Remember only the first defective line.
        if warning and not defective_csv_line_in_input:
            defective_csv_line_in_input = NR
        NF = len(afields)
        if NF not in input_fields_info:
            input_fields_info[NF] = NR
        try:
            rhs_records = None
            if joiner is not None:
                rhs_records = joiner.get_rhs(__RBQLMP__lhs_join_var)
            # __RBQLMP__process_function below is either "process_select" or "process_update"
            if not __RBQLMP__process_function(NR, NF, afields, rhs_records):
                # The writer accepts no more records.
                break
        except BadFieldError as e:
            bad_idx = e.bad_idx
            # bad_idx is zero-based; column names shown to the user are one-based.
            raise RbqlRuntimeError('No "a' + str(bad_idx + 1) + '" column at line: ' + str(NR))
        except Exception as e:
            # Any other failure is reported together with the offending line number.
            raise RbqlRuntimeError('Error at line: ' + str(NR) + ', Details: ' + str(e))
    writer.finish()
    return create_warnings_report()


# Run main() only when the file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()


